# Document setup: center all figures produced by knitr chunks.
knitr::opts_chunk$set(fig.align="center")
library(rstanarm)
library(tidyverse)
library(tidybayes)
library(modelr)
library(ggplot2)
library(magrittr)
library(emmeans)
library(bayesplot)
library(brms)
library(gganimate)
# Use a light ggplot2 theme for every plot in this document.
theme_set(theme_light())
# Project helpers; provides posterior_draws_plot() and
# expected_diff_in_mean_plot() used below.
source('helper_functions.R')
In our experiment, we used a visualization recommendation algorithm (composed of one search algorithm and one oracle algorithm) to generate visualizations for the user on one of two datasets. We then measured the user’s accuracy on two tasks: Find Extremum and Retrieve Value.
Given a search algorithm (bfs or dfs), an oracle (compassql or dziban), and a dataset (birdstrikes or movies), we would like to predict a user’s chance of answering the Find Extremum and Retrieve Value tasks correctly. In addition, we would like to know if the choice of search algorithm and oracle has any meaningful impact on a user’s accuracy for these two tasks.
# Read trial-level accuracy data and encode the experimental design
# variables as factors so the model treats them as categorical predictors.
accuracy_data <- read.csv('split_by_participant_groups/accuracy.csv') %>%
  mutate(across(c(oracle, search, dataset), as.factor))
# Result containers for the analyses below.
# NOTE(review): draw_data and the *_differences lists are reassigned wholesale
# later; the list initialization appears vestigial but is kept for safety.
models <- list()
draw_data <- list()
search_differences <- list()
oracle_differences <- list()
# Seed shared by model fitting and posterior draws for reproducibility.
seed <- 12
We derived a very weakly informative prior from our pilot studies.
# Bayesian mixed-effects logistic regression predicting trial-level accuracy.
# `0 + Intercept` suppresses brms' default intercept handling so we can place
# a prior directly on the (uncentered) intercept coefficient.
model <- brm(
bf(
# Full three-way interaction of the design factors, plus additive task and
# participant-group effects and a per-participant random intercept.
accuracy ~ 0 + Intercept + oracle * search * dataset + task + participant_group + (1 | participant_id)
),
data = accuracy_data,
# Pilot-derived prior on the intercept (logit scale -- TODO confirm the
# 0.8 center was intended on the logit, not probability, scale) and a
# weakly informative normal(0, 2.5) on all other coefficients.
prior = c(prior(normal(0.8, .1), class = "b", coef = "Intercept"),
prior(normal(0, 2.5), class = "b")),
family = bernoulli(link = "logit"),
warmup = 500,
iter = 3000,
chains = 2,
cores = 2,
# Raised adapt_delta to reduce divergent transitions.
control = list(adapt_delta = 0.9),
seed = seed,
# Cache the fit on disk; re-knitting loads this file instead of re-sampling.
file = "models/accuracy"
)
In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.
# Posterior summary; check Rhat near 1.00 and Bulk_ESS/Tail_ESS in the thousands.
summary(model)
## Family: bernoulli
## Links: mu = logit
## Formula: accuracy ~ 0 + Intercept + oracle * search * dataset + task + participant_group + (1 | participant_id)
## Data: accuracy_data (Number of observations: 132)
## Samples: 2 chains, each with iter = 3000; warmup = 500; thin = 1;
## total post-warmup samples = 5000
##
## Group-Level Effects:
## ~participant_id (Number of levels: 66)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept) 1.42 0.93 0.07 3.49 1.00 852 1751
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat
## Intercept 0.83 0.10 0.64 1.03 1.00
## oracledziban 2.68 1.17 0.63 5.14 1.00
## searchdfs 2.03 1.08 0.08 4.31 1.00
## datasetmovies 0.95 0.96 -0.78 3.03 1.00
## task2.RetrieveValue 0.72 0.63 -0.46 1.98 1.00
## participant_groupstudent 0.12 0.76 -1.20 1.79 1.00
## oracledziban:searchdfs -2.69 1.51 -5.65 0.29 1.00
## oracledziban:datasetmovies -0.87 1.52 -3.77 2.18 1.00
## searchdfs:datasetmovies 1.17 1.61 -1.97 4.38 1.00
## oracledziban:searchdfs:datasetmovies -1.34 1.87 -5.15 2.36 1.00
## Bulk_ESS Tail_ESS
## Intercept 9485 3600
## oracledziban 3025 2928
## searchdfs 3393 3639
## datasetmovies 3356 2767
## task2.RetrieveValue 7854 3716
## participant_groupstudent 3025 2087
## oracledziban:searchdfs 3505 3384
## oracledziban:datasetmovies 3435 3863
## searchdfs:datasetmovies 4827 3938
## oracledziban:searchdfs:datasetmovies 4819 3803
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
Trace plots help us check whether there is evidence of non-convergence for our model.
# Trace and marginal density plots for each parameter (visual convergence check).
plot(model)
In our pairs plots, we want to make sure we don’t have highly correlated parameters (highly correlated parameters means that our model has difficulty differentiating the effect of such parameters).
# Joint posterior scatterplots for the main-effect coefficients; strong
# pairwise correlation would mean the model cannot separate those effects.
main_effect_pars <- c(
  "b_Intercept",
  "b_datasetmovies",
  "b_oracledziban",
  "b_searchdfs",
  "b_task2.RetrieveValue"
)
pairs(model, pars = main_effect_pars, fixed = TRUE)
# Posterior predictive check: overlay densities from 100 posterior simulations
# on the observed outcome distribution.
pp_check(model, type = "dens_overlay", nsamples = 100)
A confusion matrix can be used to check our correct classification rate (a useful measure to see how well our model fits our data).
# Posterior-mean predicted probability of a correct answer for each trial;
# predict() returns a matrix whose first column is the Estimate.
pred_prob <- predict(model, type = "response")[, 1]
# Hard classification at the 0.5 threshold.
pred <- if_else(pred_prob > 0.5, 1, 0)
# Coerce both axes to factors with explicit levels so the table stays a full
# 2x2 confusion matrix even when one predicted class is empty (plain table()
# silently drops absent levels, collapsing the matrix to a single row).
confusion_matrix <- table(
  predicted = factor(pred, levels = c(0, 1)),
  observed = factor(pull(accuracy_data, accuracy), levels = c(0, 1))
)
confusion_matrix
##
## pred 0 1
## 1 10 122
Visualization of parameter effects via draws from our model posterior. The thicker line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.
# Expected-value posterior draws on the probability scale; re_formula = NA
# excludes the participant random effect, giving population-level estimates.
draw_data <- accuracy_data %>%
add_fitted_draws(model, seed = seed, re_formula = NA, scale = "response") %>%
group_by(search, oracle, dataset, task, .draw)
# Relabel conditions for presentation. Note: the $<- assignments replace the
# factor grouping columns with character vectors in place; a grouped mutate()
# would not be an exact drop-in here, so the base-R form is kept.
plot_data <- draw_data
plot_data$oracle<- gsub('compassql', 'CompassQL', plot_data$oracle)
plot_data$oracle<- gsub('dziban', 'Dziban', plot_data$oracle)
plot_data$search<- gsub('bfs', 'BFS', plot_data$search)
plot_data$search<- gsub('dfs', 'DFS', plot_data$search)
# Combined oracle/search label used as the plot's condition axis.
plot_data$condition <- paste(plot_data$oracle, plot_data$search, sep=" + ")
# Helper from helper_functions.R; faceted by dataset.
draw_plot <- posterior_draws_plot(plot_data, "dataset", TRUE, "Predicted accuracy (p_correct)", "Oracle/Search Combination")
draw_plot
Since the credible intervals on our plot overlap, we can use mean_qi to get the numeric boundaries for the different intervals.
# Numeric interval boundaries per condition cell: posterior mean of expected
# accuracy with 95% and 50% quantile intervals.
fit_info <- draw_data %>%
  group_by(search, oracle, dataset, task) %>%
  mean_qi(.value, .width = c(.95, .5))
fit_info
## # A tibble: 32 x 10
## # Groups: search, oracle, dataset [8]
## search oracle dataset task .value .lower .upper .width .point .interval
## <fct> <fct> <fct> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs compas… birdstr… 1. Find… 0.698 0.440 0.917 0.95 mean qi
## 2 bfs compas… birdstr… 2. Retr… 0.809 0.538 0.966 0.95 mean qi
## 3 bfs compas… movies 1. Find… 0.827 0.523 0.986 0.95 mean qi
## 4 bfs compas… movies 2. Retr… 0.897 0.656 0.994 0.95 mean qi
## 5 bfs dziban birdstr… 1. Find… 0.953 0.806 0.998 0.95 mean qi
## 6 bfs dziban birdstr… 2. Retr… 0.973 0.876 0.999 0.95 mean qi
## 7 bfs dziban movies 1. Find… 0.949 0.780 0.999 0.95 mean qi
## 8 bfs dziban movies 2. Retr… 0.973 0.873 1.00 0.95 mean qi
## 9 dfs compas… birdstr… 1. Find… 0.920 0.711 0.996 0.95 mean qi
## 10 dfs compas… birdstr… 2. Retr… 0.955 0.821 0.998 0.95 mean qi
## # … with 22 more rows
## Saving 7 x 5 in image
# Population-level expected draws (random effects excluded) reused by every
# pairwise-difference plot below.
predictive_data <- accuracy_data %>%
add_fitted_draws(model, seed = seed, re_formula = NA, scale = "response")
Differences in search algorithms:
# Posterior difference in mean accuracy between search algorithms (bfs - dfs),
# split by task and faceted by dataset. Helper from helper_functions.R.
search_differences <- expected_diff_in_mean_plot(predictive_data, "search", "Difference in Mean Accuracy (p_correct)", "Task", "dataset")
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
search_differences$plot
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# Interval table for the search-difference plot; an interval excluding zero
# would indicate a credible difference.
search_differences$intervals
## # A tibble: 8 x 9
## # Groups: search, dataset [2]
## search dataset task difference .lower .upper .width .point .interval
## <chr> <fct> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs - … birdstri… 1. Find… -0.0886 -0.224 0.0698 0.95 mean qi
## 2 bfs - … birdstri… 2. Retr… -0.0609 -0.192 0.0338 0.95 mean qi
## 3 bfs - … movies 1. Find… -0.0510 -0.208 0.0888 0.95 mean qi
## 4 bfs - … movies 2. Retr… -0.0313 -0.146 0.0540 0.95 mean qi
## 5 bfs - … birdstri… 1. Find… -0.0886 -0.136 -0.0477 0.5 mean qi
## 6 bfs - … birdstri… 2. Retr… -0.0609 -0.0909 -0.0245 0.5 mean qi
## 7 bfs - … movies 1. Find… -0.0510 -0.0929 -0.00687 0.5 mean qi
## 8 bfs - … movies 2. Retr… -0.0313 -0.0543 -0.00349 0.5 mean qi
Differences in oracle:
# Posterior difference in mean accuracy between oracles (dziban - compassql),
# split by task and faceted by dataset.
oracle_differences <- expected_diff_in_mean_plot(predictive_data, "oracle", "Difference in Mean Accuracy (p_correct)", "Task", "dataset")
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
oracle_differences$plot
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# Interval table for the oracle-difference plot; an interval excluding zero
# would indicate a credible difference.
oracle_differences$intervals
## # A tibble: 8 x 9
## # Groups: oracle, dataset [2]
## oracle dataset task difference .lower .upper .width .point .interval
## <chr> <fct> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban - … birdstr… 1. Find… 0.121 -0.0291 0.272 0.95 mean qi
## 2 dziban - … birdstr… 2. Retr… 0.0789 -0.0162 0.215 0.95 mean qi
## 3 dziban - … movies 1. Find… 0.0183 -0.135 0.170 0.95 mean qi
## 4 dziban - … movies 2. Retr… 0.0132 -0.0801 0.121 0.95 mean qi
## 5 dziban - … birdstr… 1. Find… 0.121 0.0762 0.166 0.5 mean qi
## 6 dziban - … birdstr… 2. Retr… 0.0789 0.0390 0.110 0.5 mean qi
## 7 dziban - … movies 1. Find… 0.0183 -0.0212 0.0592 0.5 mean qi
## 8 dziban - … movies 2. Retr… 0.0132 -0.0102 0.0348 0.5 mean qi
Differences in participant group (student vs professional):
# Posterior difference in mean accuracy between participant groups
# (student - professional), collapsed over dataset (facet argument NULL).
participant_group_differences <- expected_diff_in_mean_plot(predictive_data, "participant_group", "Difference in Mean Accuracy (p_correct)", "Task", NULL)
## `summarise()` regrouping output by 'participant_group', 'task', 'dataset' (override with `.groups` argument)
participant_group_differences$plot
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# Interval table for the participant-group comparison; both widths straddle
# zero in the knitted output above.
participant_group_differences$intervals
## # A tibble: 4 x 8
## # Groups: participant_group [1]
## participant_group task difference .lower .upper .width .point .interval
## <chr> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 student - profess… 1. Find … -0.000481 -0.120 0.113 0.95 mean qi
## 2 student - profess… 2. Retri… 0.00170 -0.0758 0.0926 0.95 mean qi
## 3 student - profess… 1. Find … -0.000481 -0.0345 0.0339 0.5 mean qi
## 4 student - profess… 2. Retri… 0.00170 -0.0191 0.0201 0.5 mean qi
# Same participant-group comparison, but faceted by dataset.
participant_group_differences_dataset <- expected_diff_in_mean_plot(predictive_data, "participant_group", "Difference in Mean Accuracy (p_correct)", "Task", "dataset")
## `summarise()` regrouping output by 'participant_group', 'task', 'dataset' (override with `.groups` argument)
participant_group_differences_dataset$plot